import torch
import torchvision.models as models

# Download a MobileNetV2 checkpoint, load it into a torchvision model to verify
# that the weights fit the architecture, then store the raw checkpoint locally.

# Remote location of the checkpoint (judging by the repository name, MobileNetV2
# trained on Imagenette).
CHECKPOINT_URL = 'https://huggingface.co/alexsu52/mobilenet_v2_imagenette/resolve/main/pytorch_model.bin'
# Imagenette is a 10-class subset of ImageNet, so the classifier head must have
# 10 outputs; building it with 1000 makes load_state_dict() fail with a size
# mismatch on the classifier weights.
IMAGENETTE_CLASSES = 10

model = models.mobilenet_v2(num_classes=IMAGENETTE_CLASSES)

# map_location pins every tensor to the CPU so that loading does not depend on
# the device the checkpoint was originally saved from.
checkpoint = torch.hub.load_state_dict_from_url(
    CHECKPOINT_URL,
    map_location=torch.device("cpu"),
    progress=False,
)
model.load_state_dict(checkpoint['state_dict'])

# Save the entire checkpoint (the 'state_dict' entry plus any other keys it may hold).
SAVE_PATH = "mobilenetv2_imagenette_checkpoint.pth"  # a .pt suffix works as well
torch.save(checkpoint, SAVE_PATH)
print(f"Checkpoint 已保存到: {SAVE_PATH}")